{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Using a bi-LSTM to sort a sequence of integers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import random\n",
    "import string\n",
    "\n",
    "import mxnet as mx\n",
    "from mxnet import gluon, nd\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Data Preparation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "max_num = 999\n",
    "dataset_size = 60000\n",
    "seq_len = 5\n",
    "split = 0.8\n",
    "batch_size = 512\n",
    "ctx = mx.gpu() if mx.context.num_gpus() > 0 else mx.cpu()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We generate a dataset of **dataset_size** sequences, each made of **seq_len** integers between **0** and **max_num**. We use **split** × 100% of them (80% here) for training and the rest for testing.\n",
    "\n",
    "\n",
    "For example:\n",
    "\n",
    "50 10 200 999 30\n",
    "\n",
    "The network should output\n",
    "\n",
    "10 30 50 200 999"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
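    "# Integers drawn uniformly from [0, max_num]; randint's upper bound is exclusive, hence max_num + 1\n",
    "X = np.random.randint(0, max_num + 1, size=(dataset_size, seq_len)).astype('int32')\n",
    "Y = X.copy()\n",
    "Y.sort()  # sort each row in place to get the target"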
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Input [548, 592, 714, 843, 602]\n",
      "Target [548, 592, 602, 714, 843]\n"
     ]
    }
   ],
   "source": [
    "print(\"Input {}\\nTarget {}\".format(X[0].tolist(), Y[0].tolist()))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For training, we encode the input as a string of characters rather than as numbers."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0123456789 \n",
      "{'0': 0, '1': 1, '2': 2, '3': 3, '4': 4, '5': 5, '6': 6, '7': 7, '8': 8, '9': 9, ' ': 10}\n"
     ]
    }
   ],
   "source": [
    "vocab = string.digits + \" \"\n",
    "print(vocab)\n",
    "vocab_idx = { c:i for i,c in enumerate(vocab)}\n",
    "print(vocab_idx)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We write a transform that converts each sequence of numbers into a space-separated string, pads it with spaces to the maximum length **max_len**, and maps each character to its index in the vocabulary.\n",
    "For example:\n",
    "\n",
    "\"30 10\" maps to the indices [3, 0, 10, 1, 0]\n",
    "\n",
    "We then one-hot encode these indices to get a matrix representation of the input. The target does not need to be one-hot encoded, because the loss we use (`SoftmaxCELoss`) supports sparse labels, i.e. plain character indices."
   ]
  },
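  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a minimal, framework-free sketch of this encoding, the cell below uses plain Python lists instead of NDArrays. The helpers `encode_sequence` and `one_hot` are hypothetical names used only for illustration; the notebook's own `transform` is what the training pipeline actually uses. With five 3-digit numbers and 4 separators, the padded string is 3 × 5 + 4 = 19 characters long."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import string\n",
    "\n",
    "# Same vocabulary as the notebook: the ten digits plus a space (separator and padding)\n",
    "vocab = string.digits + ' '\n",
    "vocab_idx = {c: i for i, c in enumerate(vocab)}\n",
    "\n",
    "\n",
    "def encode_sequence(numbers, max_len):\n",
    "    # Hypothetical helper: space-separated string, right-padded with spaces, mapped to indices\n",
    "    text = ' '.join(str(n) for n in numbers)\n",
    "    text = text + ' ' * (max_len - len(text))\n",
    "    return [vocab_idx[c] for c in text]\n",
    "\n",
    "\n",
    "def one_hot(indices, depth):\n",
    "    # One row per character, with 1.0 in the column of that character's index\n",
    "    return [[1.0 if j == i else 0.0 for j in range(depth)] for i in indices]\n",
    "\n",
    "\n",
    "print(encode_sequence([30, 10], max_len=5))  # [3, 0, 10, 1, 0]\n",
    "# Five 3-digit numbers plus 4 separators: max_len = 3 * 5 + 4 = 19\n",
    "print(encode_sequence([548, 592, 714, 843, 602], max_len=19))\n",
    "print(one_hot([3, 10], depth=len(vocab)))"
   ]
  },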
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Maximum length of the string: 19\n"
     ]
    }
   ],
   "source": [
    "max_len = len(str(max_num))*seq_len+(seq_len-1)\n",
    "print(\"Maximum length of the string: %s\" % max_len)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
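    "def transform(x, y):\n",
    "    # Join the numbers with spaces and right-pad with spaces up to max_len characters\n",
    "    x_string = ' '.join(map(str, x.tolist()))\n",
    "    x_string_padded = x_string + ' '*(max_len-len(x_string))\n",
    "    # Map each character to its index in the vocabulary\n",
    "    x = [vocab_idx[c] for c in x_string_padded]\n",
    "    # Encode the sorted target the same way, but keep it as sparse indices\n",
    "    y_string = ' '.join(map(str, y.tolist()))\n",
    "    y_string_padded = y_string + ' '*(max_len-len(y_string))\n",
    "    y = [vocab_idx[c] for c in y_string_padded]\n",
    "    # Input: one-hot matrix of shape (max_len, len(vocab)); target: index vector of shape (max_len,)\n",
    "    return mx.nd.one_hot(mx.nd.array(x), len(vocab)), mx.nd.array(y)"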
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "split_idx = int(split*len(X))\n",
    "train_dataset = gluon.data.ArrayDataset(X[:split_idx], Y[:split_idx]).transform(transform)\n",
    "test_dataset = gluon.data.ArrayDataset(X[split_idx:], Y[split_idx:]).transform(transform)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Input [548 592 714 843 602]\n",
      "Transformed data Input \n",
      "[[0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0.]\n",
      " [0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]\n",
      " [0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0.]\n",
      " [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]\n",
      " [0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0.]\n",
      " [0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0.]\n",
      " [0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0.]\n",
      " [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]\n",
      " [0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0.]\n",
      " [0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n",
      " [0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]\n",
      " [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]\n",
      " [0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0.]\n",
      " [0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]\n",
      " [0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0.]\n",
      " [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]\n",
      " [0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0.]\n",
      " [1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n",
      " [0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0.]]\n",
      "<NDArray 19x11 @cpu(0)>\n",
      "Target [548 592 602 714 843]\n",
      "Transformed data Target \n",
      "[ 5.  4.  8. 10.  5.  9.  2. 10.  6.  0.  2. 10.  7.  1.  4. 10.  8.  4.\n",
      "  3.]\n",
      "<NDArray 19 @cpu(0)>\n"
     ]
    }
   ],
   "source": [
    "print(\"Input {}\".format(X[0]))\n",
    "print(\"Transformed data Input {}\".format(train_dataset[0][0]))\n",
    "print(\"Target {}\".format(Y[0]))\n",
    "print(\"Transformed data Target {}\".format(train_dataset[0][1]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_data = gluon.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=20, last_batch='rollover')\n",
    "test_data = gluon.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=5, last_batch='keep')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Creating the network"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "net = gluon.nn.HybridSequential()\n",
    "with net.name_scope():\n",
    "    net.add(\n",
    "        gluon.rnn.LSTM(hidden_size=128, num_layers=2, layout='NTC', bidirectional=True),\n",
    "        gluon.nn.Dense(len(vocab), flatten=False)\n",
    "    )"
   ]
  },
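  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the shapes concrete: with `layout='NTC'`, the LSTM takes a batch of shape (batch, time, features) = (batch_size, max_len, len(vocab)). Because it is bidirectional, it concatenates the forward and backward hidden states, so each time step carries 2 × 128 = 256 features; stacking `num_layers=2` does not change that output shape. `Dense(len(vocab), flatten=False)` is then applied to the last axis at every time step, giving one score per vocabulary character per position. The plain-Python helper below (`sorter_output_shapes`, a hypothetical name) only tracks these shapes; it does not run the network."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def sorter_output_shapes(batch, seq_len_chars, vocab_size, hidden_size=128, bidirectional=True):\n",
    "    # Hypothetical shape tracker for LSTM(layout='NTC') followed by Dense(vocab_size, flatten=False)\n",
    "    input_shape = (batch, seq_len_chars, vocab_size)  # one-hot characters\n",
    "    directions = 2 if bidirectional else 1\n",
    "    # Forward and backward states are concatenated on the feature axis\n",
    "    lstm_shape = (batch, seq_len_chars, directions * hidden_size)\n",
    "    # flatten=False keeps the time axis: each step's features map to vocab_size scores\n",
    "    dense_shape = (batch, seq_len_chars, vocab_size)\n",
    "    return input_shape, lstm_shape, dense_shape\n",
    "\n",
    "\n",
    "for shape in sorter_output_shapes(batch=512, seq_len_chars=19, vocab_size=11):\n",
    "    print(shape)"
   ]
  },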
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "net.initialize(mx.init.Xavier(), ctx=ctx)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "loss = gluon.loss.SoftmaxCELoss()"
   ]
  },
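  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`SoftmaxCELoss` uses sparse labels by default (`sparse_label=True`): the label at each position is the index of the correct character rather than a one-hot vector. For an output of shape (batch, max_len, len(vocab)) and labels of shape (batch, max_len), it applies a log-softmax over the last axis, picks the log-probability of the correct character at each position, and averages the negated values over the non-batch axes, giving one loss value per sequence. The plain-Python sketch below mirrors that computation for a single sequence; `sequence_ce_loss` is a hypothetical name, not part of Gluon. A model that spreads its scores uniformly over the 11 characters gets log(11) ≈ 2.40, which is a useful reference point for the first epochs of training."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "\n",
    "\n",
    "def log_softmax(scores):\n",
    "    # Numerically stable log-softmax over one position's vocabulary scores\n",
    "    m = max(scores)\n",
    "    log_sum = m + math.log(sum(math.exp(s - m) for s in scores))\n",
    "    return [s - log_sum for s in scores]\n",
    "\n",
    "\n",
    "def sequence_ce_loss(score_rows, label_indices):\n",
    "    # score_rows: one list of vocabulary scores per character position\n",
    "    # label_indices: index of the correct character at each position (sparse label)\n",
    "    per_position = [-log_softmax(row)[label] for row, label in zip(score_rows, label_indices)]\n",
    "    # Average over positions, as SoftmaxCELoss averages over all non-batch axes\n",
    "    return sum(per_position) / len(per_position)\n",
    "\n",
    "\n",
    "# With uniform scores over 11 characters, every position costs log(11)\n",
    "uniform = [[0.0] * 11 for _ in range(19)]\n",
    "print(sequence_ce_loss(uniform, [10] * 19), math.log(11))"
   ]
  },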
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We use a learning rate schedule that multiplies the learning rate by 0.75 about every 10 epochs (`len(train_data)*10` updates) to help the model converge."
   ]
  },
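  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`FactorScheduler` multiplies the learning rate by `factor` each time the update count moves past another `step` updates. The plain-Python sketch below re-implements that rule as we understand MXNet's implementation, leaving out its `stop_factor_lr` floor and warmup options. `factor_lr` is a hypothetical name, and `step=930` is an assumed stand-in for `len(train_data)*10` (48,000 training samples in batches of 512 give 93 batches per epoch). Ten decays give 0.01 × 0.75^10 ≈ 0.00056, the learning rate reported for the last epoch in the training log."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def factor_lr(num_update, base_lr=0.01, step=930, factor=0.75):\n",
    "    # Decay by `factor` once for every full `step` updates that num_update has gone past\n",
    "    # (strictly greater-than, so the first decay happens at update step + 1)\n",
    "    lr = base_lr\n",
    "    count = 0\n",
    "    while num_update > count + step:\n",
    "        count += step\n",
    "        lr *= factor\n",
    "    return lr\n",
    "\n",
    "\n",
    "for num_update in (1, 930, 931, 1861, 9301):\n",
    "    print(num_update, factor_lr(num_update))"
   ]
  },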
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "schedule = mx.lr_scheduler.FactorScheduler(step=len(train_data)*10, factor=0.75)\n",
    "schedule.base_lr = 0.01"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer = gluon.Trainer(net.collect_params(), 'adam', {'learning_rate':0.01, 'lr_scheduler':schedule})"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training loop"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch [0] Loss: 1.6627886372227823, LR 0.01\n",
      "Epoch [1] Loss: 1.210370733382854, LR 0.01\n",
      "Epoch [2] Loss: 0.9692377131035987, LR 0.01\n",
      "Epoch [3] Loss: 0.7976046623067653, LR 0.01\n",
      "Epoch [4] Loss: 0.5714595343476983, LR 0.01\n",
      "Epoch [5] Loss: 0.4458411196444897, LR 0.01\n",
      "Epoch [6] Loss: 0.36039798817736035, LR 0.01\n",
      "Epoch [7] Loss: 0.32665719377233626, LR 0.01\n",
      "Epoch [8] Loss: 0.262064205702915, LR 0.01\n",
      "Epoch [9] Loss: 0.22285924059279422, LR 0.0075\n",
      "Epoch [10] Loss: 0.19018426854559717, LR 0.0075\n",
      "Epoch [11] Loss: 0.1718730723604243, LR 0.0075\n",
      "Epoch [12] Loss: 0.15736752171670237, LR 0.0075\n",
      "Epoch [13] Loss: 0.14579375246737866, LR 0.0075\n",
      "Epoch [14] Loss: 0.13546599733068587, LR 0.0075\n",
      "Epoch [15] Loss: 0.12490207590955368, LR 0.0075\n",
      "Epoch [16] Loss: 0.11803316300915133, LR 0.0075\n",
      "Epoch [17] Loss: 0.10653189395336395, LR 0.0075\n",
      "Epoch [18] Loss: 0.10514750379197141, LR 0.0075\n",
      "Epoch [19] Loss: 0.09590611559279422, LR 0.005625\n",
      "Epoch [20] Loss: 0.08146028108494256, LR 0.005625\n",
      "Epoch [21] Loss: 0.07707348782965477, LR 0.005625\n",
      "Epoch [22] Loss: 0.07206193436967566, LR 0.005625\n",
      "Epoch [23] Loss: 0.07001185417175293, LR 0.005625\n",
      "Epoch [24] Loss: 0.06797058351578252, LR 0.005625\n",
      "Epoch [25] Loss: 0.0649358110224947, LR 0.005625\n",
      "Epoch [26] Loss: 0.06219124286732775, LR 0.005625\n",
      "Epoch [27] Loss: 0.06075144828634059, LR 0.005625\n",
      "Epoch [28] Loss: 0.05711334495134251, LR 0.005625\n",
      "Epoch [29] Loss: 0.054747099572039666, LR 0.00421875\n",
      "Epoch [30] Loss: 0.0441775271233092, LR 0.00421875\n",
      "Epoch [31] Loss: 0.041551097910454936, LR 0.00421875\n",
      "Epoch [32] Loss: 0.04095017269093503, LR 0.00421875\n",
      "Epoch [33] Loss: 0.04045371045457556, LR 0.00421875\n",
      "Epoch [34] Loss: 0.038867686657195394, LR 0.00421875\n",
      "Epoch [35] Loss: 0.038131744303601854, LR 0.00421875\n",
      "Epoch [36] Loss: 0.039834817250569664, LR 0.00421875\n",
      "Epoch [37] Loss: 0.03669035941996473, LR 0.00421875\n",
      "Epoch [38] Loss: 0.03373505967728635, LR 0.00421875\n",
      "Epoch [39] Loss: 0.03164981273894615, LR 0.0031640625\n",
      "Epoch [40] Loss: 0.025532766055035336, LR 0.0031640625\n",
      "Epoch [41] Loss: 0.022659448867148543, LR 0.0031640625\n",
      "Epoch [42] Loss: 0.02307056112492338, LR 0.0031640625\n",
      "Epoch [43] Loss: 0.02236944056571798, LR 0.0031640625\n",
      "Epoch [44] Loss: 0.022204211963120328, LR 0.0031640625\n",
      "Epoch [45] Loss: 0.02262336903430046, LR 0.0031640625\n",
      "Epoch [46] Loss: 0.02253308448385685, LR 0.0031640625\n",
      "Epoch [47] Loss: 0.025286573044797207, LR 0.0031640625\n",
      "Epoch [48] Loss: 0.02439300988310127, LR 0.0031640625\n",
      "Epoch [49] Loss: 0.017976388018181983, LR 0.002373046875\n",
      "Epoch [50] Loss: 0.014343131095805067, LR 0.002373046875\n",
      "Epoch [51] Loss: 0.013039355582379281, LR 0.002373046875\n",
      "Epoch [52] Loss: 0.011884741885687715, LR 0.002373046875\n",
      "Epoch [53] Loss: 0.011438189668858305, LR 0.002373046875\n",
      "Epoch [54] Loss: 0.011447292693117832, LR 0.002373046875\n",
      "Epoch [55] Loss: 0.014212571560068334, LR 0.002373046875\n",
      "Epoch [56] Loss: 0.019900493724371797, LR 0.002373046875\n",
      "Epoch [57] Loss: 0.02102568301748722, LR 0.002373046875\n",
      "Epoch [58] Loss: 0.01346214400961044, LR 0.002373046875\n",
      "Epoch [59] Loss: 0.010107964911359422, LR 0.0017797851562500002\n",
      "Epoch [60] Loss: 0.008353193600972494, LR 0.0017797851562500002\n",
      "Epoch [61] Loss: 0.007678258292218472, LR 0.0017797851562500002\n",
      "Epoch [62] Loss: 0.007262124660167288, LR 0.0017797851562500002\n",
      "Epoch [63] Loss: 0.00705223578087827, LR 0.0017797851562500002\n",
      "Epoch [64] Loss: 0.006788556293774677, LR 0.0017797851562500002\n",
      "Epoch [65] Loss: 0.006473606571238091, LR 0.0017797851562500002\n",
      "Epoch [66] Loss: 0.006206096486842378, LR 0.0017797851562500002\n",
      "Epoch [67] Loss: 0.00584477313021396, LR 0.0017797851562500002\n",
      "Epoch [68] Loss: 0.005648705267137097, LR 0.0017797851562500002\n",
      "Epoch [69] Loss: 0.006481769871204458, LR 0.0013348388671875003\n",
      "Epoch [70] Loss: 0.008430448618341, LR 0.0013348388671875003\n",
      "Epoch [71] Loss: 0.006877245421105242, LR 0.0013348388671875003\n",
      "Epoch [72] Loss: 0.005671108281740578, LR 0.0013348388671875003\n",
      "Epoch [73] Loss: 0.004832422162624116, LR 0.0013348388671875003\n",
      "Epoch [74] Loss: 0.004441103402604448, LR 0.0013348388671875003\n",
      "Epoch [75] Loss: 0.004216198591475791, LR 0.0013348388671875003\n",
      "Epoch [76] Loss: 0.004041922989711967, LR 0.0013348388671875003\n",
      "Epoch [77] Loss: 0.003937713643337818, LR 0.0013348388671875003\n",
      "Epoch [78] Loss: 0.010251983049068046, LR 0.0013348388671875003\n",
      "Epoch [79] Loss: 0.01829354052848004, LR 0.0010011291503906252\n",
      "Epoch [80] Loss: 0.006723233448561802, LR 0.0010011291503906252\n",
      "Epoch [81] Loss: 0.004397524798170049, LR 0.0010011291503906252\n",
      "Epoch [82] Loss: 0.0038475305476087206, LR 0.0010011291503906252\n",
      "Epoch [83] Loss: 0.003591177945441388, LR 0.0010011291503906252\n",
      "Epoch [84] Loss: 0.003425112014175743, LR 0.0010011291503906252\n",
      "Epoch [85] Loss: 0.0032633850549129728, LR 0.0010011291503906252\n",
      "Epoch [86] Loss: 0.0031762316505959693, LR 0.0010011291503906252\n",
      "Epoch [87] Loss: 0.0030452777096565734, LR 0.0010011291503906252\n",
      "Epoch [88] Loss: 0.002950224184220837, LR 0.0010011291503906252\n",
      "Epoch [89] Loss: 0.002821172171450676, LR 0.0007508468627929689\n",
      "Epoch [90] Loss: 0.002725780961361337, LR 0.0007508468627929689\n",
      "Epoch [91] Loss: 0.002660556359493986, LR 0.0007508468627929689\n",
      "Epoch [92] Loss: 0.0026011724946319414, LR 0.0007508468627929689\n",
      "Epoch [93] Loss: 0.0025355776256703317, LR 0.0007508468627929689\n",
      "Epoch [94] Loss: 0.0024825221997626283, LR 0.0007508468627929689\n",
      "Epoch [95] Loss: 0.0024245587435174497, LR 0.0007508468627929689\n",
      "Epoch [96] Loss: 0.002365282145879602, LR 0.0007508468627929689\n",
      "Epoch [97] Loss: 0.0023112583984719946, LR 0.0007508468627929689\n",
      "Epoch [98] Loss: 0.002257173682780976, LR 0.0007508468627929689\n",
      "Epoch [99] Loss: 0.002162747085094452, LR 0.0005631351470947267\n"
     ]
    }
   ],
   "source": [
    "epochs = 100\n",
    "for e in range(epochs):\n",
    "    epoch_loss = 0.\n",
    "    for i, (data, label) in enumerate(train_data):\n",
    "        data = data.as_in_context(ctx)\n",
    "        label = label.as_in_context(ctx)\n",
    "\n",
    "        with mx.autograd.record():\n",
    "            output = net(data)\n",
    "            l = loss(output, label)\n",
    "\n",
    "        l.backward()\n",
    "        trainer.step(data.shape[0])\n",
    "    \n",
    "        epoch_loss += l.mean()\n",
    "        \n",
    "    print(\"Epoch [{}] Loss: {}, LR {}\".format(e, epoch_loss.asscalar()/(i+1), trainer.learning_rate))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Testing"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We pick a random element from the test set."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "# len(test_dataset) counts test samples; len(test_data) would count batches instead\n",
    "n = random.randint(0, len(test_dataset)-1)\n",
    "\n",
    "x_orig = X[split_idx+n]\n",
    "y_orig = Y[split_idx+n]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_pred(x):\n",
    "    x, _ = transform(x, x)  # only the encoded input is needed, so x doubles as a dummy target\n",
    "    output = net(x.as_in_context(ctx).expand_dims(axis=0))\n",
    "\n",
    "    # Convert output back to string\n",
    "    pred = ''.join([vocab[int(o)] for o in output[0].argmax(axis=1).asnumpy().tolist()])\n",
    "    return pred"
   ]
  },
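  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`get_pred` turns the network output back into text by taking, at each of the **max_len** positions, the index of the highest-scoring character and looking it up in `vocab`. The plain-Python sketch below shows that greedy decoding step on hand-made scores; `decode_scores` and `scores_for` are hypothetical names, and stripping the trailing padding spaces is an optional extra that `get_pred` does not do."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import string\n",
    "\n",
    "vocab = string.digits + ' '\n",
    "\n",
    "\n",
    "def decode_scores(score_rows, strip_padding=True):\n",
    "    # Greedy decoding: pick the highest-scoring vocabulary index at each position\n",
    "    indices = [max(range(len(row)), key=lambda j: row[j]) for row in score_rows]\n",
    "    text = ''.join(vocab[i] for i in indices)\n",
    "    # Padding positions decode to spaces; optionally drop them from the end\n",
    "    return text.rstrip(' ') if strip_padding else text\n",
    "\n",
    "\n",
    "def scores_for(text):\n",
    "    # Hypothetical helper: fake scores with 1.0 for the wanted character and 0.0 elsewhere\n",
    "    return [[1.0 if vocab[j] == c else 0.0 for j in range(len(vocab))] for c in text]\n",
    "\n",
    "\n",
    "print(repr(decode_scores(scores_for('10 30   '))))\n",
    "print(repr(decode_scores(scores_for('10 30   '), strip_padding=False)))"
   ]
  },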
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We print the input, the prediction, and the expected label:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "X         611 671 275 871 944\n",
      "Predicted 275 611 671 871 944\n",
      "Label     275 611 671 871 944\n"
     ]
    }
   ],
   "source": [
    "x_ = ' '.join(map(str,x_orig))\n",
    "label = ' '.join(map(str,y_orig))\n",
    "print(\"X         {}\\nPredicted {}\\nLabel     {}\".format(x_, get_pred(x_orig), label))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can also pick our own example, and the network sorts it correctly:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 66,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "10 30 130 500 999  \n"
     ]
    }
   ],
   "source": [
    "print(get_pred(np.array([500, 30, 999, 10, 130])))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In this example, the model even handles a sequence shorter than those it was trained on (four numbers instead of **seq_len** = 5):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 64,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Only four numbers: 105 202 302 501    \n"
     ]
    }
   ],
   "source": [
    "print(\"Only four numbers:\", get_pred(np.array([105, 302, 501, 202])))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
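    "However, it has trouble with other edge cases, such as sequences of mostly one-digit numbers, which are rare in the uniformly sampled training data, or sequences longer than **seq_len**:"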
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 63,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Small digits: 8  0 42 28         \n",
      "Small digits, 6 numbers: 10 0 20 82 71 115  \n"
     ]
    }
   ],
   "source": [
    "print(\"Small digits:\", get_pred(np.array([10, 3, 5, 2, 8])))\n",
    "print(\"Small digits, 6 numbers:\", get_pred(np.array([10, 33, 52, 21, 82, 10])))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
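    "This could likely be improved by adjusting the training dataset accordingly, for example by sampling small numbers more often and by training on sequences of varying length (with **seq_len** and **max_len** raised to cover the longest ones)."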
   ]
  }
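  ,
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As one possible adjustment, the sketch below picks each sequence length uniformly between 1 and **seq_len** and gives every number a random digit count, so that one- and two-digit numbers are as common as three-digit ones. The function name `make_varied_dataset` and this sampling scheme are illustrative assumptions that have not been trained on in this notebook. The sequences are at most **seq_len** long, so their string form still fits in **max_len**; handling longer inputs such as the six-number example would also require a larger **seq_len** and **max_len**."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import random\n",
    "\n",
    "\n",
    "def make_varied_dataset(size, max_seq_len=5, max_num=999, seed=0):\n",
    "    # Hypothetical generator: variable-length sequences with a balanced mix of digit counts\n",
    "    rng = random.Random(seed)\n",
    "    max_digits = len(str(max_num))\n",
    "    inputs, targets = [], []\n",
    "    for _ in range(size):\n",
    "        length = rng.randint(1, max_seq_len)\n",
    "        seq = []\n",
    "        for _ in range(length):\n",
    "            digits = rng.randint(1, max_digits)\n",
    "            # A number with `digits` digits lies in [10**(digits-1), 10**digits - 1]; 0 counts as 1-digit\n",
    "            low = 0 if digits == 1 else 10 ** (digits - 1)\n",
    "            high = min(10 ** digits - 1, max_num)\n",
    "            seq.append(rng.randint(low, high))\n",
    "        inputs.append(seq)\n",
    "        targets.append(sorted(seq))\n",
    "    return inputs, targets\n",
    "\n",
    "\n",
    "X_varied, Y_varied = make_varied_dataset(5)\n",
    "for seq, target in zip(X_varied, Y_varied):\n",
    "    print(seq, '->', target)"
   ]
  }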
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
